import torch
import math
from torch import nn
from image_synthesis.modeling.codecs.image_codec.openai_dvae import OpenAIDiscreteVAE
from image_synthesis.utils.misc import instantiate_from_config


class DALLE(nn.Module):
    """Token-based conditional generator built from three config-driven parts.

    * ``content_codec``   -- tokenizes the content entry of a batch (default key
      ``'image'``) and decodes generated tokens back.
    * ``condition_codec`` -- tokenizes the condition entry of a batch (default
      key ``'text'``).
    * ``transformer``     -- consumes the prepared token dict in ``forward`` and
      produces content tokens via its ``sample`` method.
    """

    def __init__(
        self,
        *,
        content_info=None,
        condition_info=None,
        content_codec_config,
        condition_codec_config,
        transformer_config
    ):
        """Instantiate the codecs and the transformer from their configs.

        Args:
            content_info: dict whose ``'key'`` names the content entry of a
                batch; defaults to ``{'key': 'image'}``.
            condition_info: dict whose ``'key'`` names the condition entry of a
                batch; defaults to ``{'key': 'text'}``.
            content_codec_config: config passed to ``instantiate_from_config``.
            condition_codec_config: config passed to ``instantiate_from_config``.
            transformer_config: config passed to ``instantiate_from_config``.
        """
        super().__init__()
        # None sentinels avoid sharing one mutable default dict between instances.
        self.content_info = {'key': 'image'} if content_info is None else content_info
        self.condition_info = {'key': 'text'} if condition_info is None else condition_info
        self.content_codec = instantiate_from_config(content_codec_config)
        self.condition_codec = instantiate_from_config(condition_codec_config)
        self.transformer = instantiate_from_config(transformer_config)

    def parameters(self, recurse=True, name=None):
        """Return the parameters of the whole model or of selected sub-modules.

        Args:
            recurse: forwarded to the underlying ``parameters`` calls.
            name: ``None`` or ``'none'`` selects every parameter (plain
                ``nn.Module`` behaviour). Otherwise it is a ``'+'``-separated
                list of attribute names, e.g. ``'transformer'`` or
                ``'transformer+content_codec'``.

        Returns:
            The ``nn.Module`` parameter iterator when everything is selected,
            otherwise a list with the parameters of each named sub-module in
            the order given.
        """
        if name is None or name == 'none':
            return super().parameters(recurse=recurse)
        params = []
        for sub_name in name.split('+'):
            module = getattr(self, sub_name)
            try:
                # Some sub-modules override parameters() with a ``name`` argument ...
                params += module.parameters(recurse=recurse, name=sub_name)
            except TypeError:
                # ... while the stock nn.Module.parameters() does not accept it.
                params += module.parameters(recurse=recurse)
        return params

    @property
    def device(self):
        """Device reported by ``self.transformer.device``."""
        return self.transformer.device

    def get_ema_model(self):
        """Return ``self.transformer``.

        NOTE(review): the name suggests this is the module an EMA helper
        tracks; confirm against the caller.
        """
        return self.transformer

    def _prefix_and_move(self, tokens, prefix):
        """Return ``tokens`` with every key prefixed and tensor values moved to ``self.device``."""
        device = self.device
        return {
            prefix + key: value.to(device) if torch.is_tensor(value) else value
            for key, value in tokens.items()
        }

    @staticmethod
    def _replicate(tokens, times):
        """Return a dict whose non-None values are ``tokens`` values concatenated ``times`` times on dim 0."""
        return {
            key: torch.cat([value] * times, dim=0) if value is not None else value
            for key, value in tokens.items()
        }

    def prepare_condition(self, batch, condition=None):
        """Tokenize the condition; keys of the codec output get a ``'condition_'`` prefix.

        Args:
            batch: mapping holding the raw condition under
                ``self.condition_info['key']``. Only read when ``condition``
                is None.
            condition: optional raw condition that overrides the batch entry.
        """
        if condition is None:
            condition = batch[self.condition_info['key']]
        if torch.is_tensor(condition):
            condition = condition.to(self.device)
        tokens = self.condition_codec.get_tokens(condition)
        return self._prefix_and_move(tokens, 'condition_')

    def prepare_content(self, batch, with_mask=False):
        """Tokenize the content; keys of the codec output get a ``'content_'`` prefix.

        Args:
            batch: mapping holding the raw content under
                ``self.content_info['key']`` (and ``'mask'`` when
                ``with_mask`` is set).
            with_mask: pass ``batch['mask']`` to the codec as well.
        """
        content = batch[self.content_info['key']]
        if torch.is_tensor(content):
            content = content.to(self.device)
        if with_mask:
            # NOTE(review): the mask is always read from the fixed key 'mask',
            # independent of the content key; confirm this is intended.
            tokens = self.content_codec.get_tokens(content, batch['mask'], enc_with_mask=False)
        else:
            tokens = self.content_codec.get_tokens(content)
        return self._prefix_and_move(tokens, 'content_')

    def prepare_input(self, batch):
        """Merge the prepared condition and content dicts into one transformer input."""
        inputs = self.prepare_condition(batch)
        inputs.update(self.prepare_content(batch))
        return inputs

    def half(self):
        """Cast child modules' floating-point tensors to float16 and return ``self``.

        Integer tensors are left untouched. ``OpenAIDiscreteVAE`` children are
        skipped and keep their current precision.
        NOTE(review): confirm whether the dVAE should also be halved (e.g.
        via its own ``half()``).
        """
        for child in self.children():
            if isinstance(child, OpenAIDiscreteVAE):
                continue
            child._apply(lambda t: t.half() if t.is_floating_point() else t)
        return self

    @staticmethod
    def _prepend_condition_attention(cond_att, content_att):
        """Stack condition attention rows before the spatial content attention.

        The last dim of ``cond_att`` (condition length) is zero-padded to whole
        rows of width ``W = content_att.shape[-1]``, plus one extra all-zero
        row as a visual separator. It is then viewed as rows of ``W`` and
        concatenated before ``content_att`` along dim -2. Only the first two
        dims of ``cond_att`` are kept as leading dims, so a 3-D
        ``cond_att`` is assumed.
        """
        width = content_att.shape[-1]
        cond_len = cond_att.shape[-1]
        padded_len = math.ceil(cond_len / float(width)) * width + width
        rows = torch.zeros(*cond_att.shape[:2], padded_len).to(cond_att)
        rows[:, :, :cond_len] = cond_att
        rows = rows.view(*cond_att.shape[:2], -1, width)
        return torch.cat([rows, content_att], dim=-2)

    @torch.no_grad()
    def generate_content(
        self,
        *,
        batch,
        condition=None,
        filter_ratio=0.5,
        temperature=1.0,
        content_ratio=0.0,
        replicate=1,
        return_att_weight=False,
    ):
        """Sample content for one setting and decode it with the content codec.

        The model runs in eval mode. Its previous train/eval mode is restored
        afterwards, also when an error is raised.

        Args:
            batch: mapping with the raw inputs. The content entry is only read
                when ``content_ratio > 0``.
            condition: optional raw condition overriding the batch entry.
            filter_ratio: forwarded to ``self.transformer.sample``.
            temperature: forwarded to ``self.transformer.sample``.
            content_ratio: fraction of the encoded content tokens kept as a
                prefix for the transformer; ``0`` means no prefix.
            replicate: repeat every condition/content tensor this many times
                along dim 0 before sampling.
            return_att_weight: also return attention maps.

        Returns:
            dict with ``'content'`` (decoded output). When
            ``return_att_weight`` is set, it also holds
            ``'condition_attention'`` (as returned by the transformer) and
            ``'content_attention'`` (content attention viewed with the codec's
            ``token_shape`` as its last two dims, with the condition attention
            stacked before it along dim -2).
        """
        was_training = self.training
        self.eval()
        try:
            if condition is None:
                condition = self.prepare_condition(batch=batch)
            else:
                condition = self.prepare_condition(batch=None, condition=condition)
            content = self.prepare_content(batch=batch) if content_ratio > 0 else None

            if replicate != 1:
                condition = self._replicate(condition, replicate)
                if content is not None:
                    content = self._replicate(content, replicate)

            content_token = None
            if content_ratio > 0:
                full_tokens = content['content_token']
                keep_len = int(full_tokens.shape[1] * content_ratio)
                content_token = full_tokens[:, :keep_len]

            trans_out = self.transformer.sample(condition_token=condition['condition_token'],
                                                condition_mask=condition.get('condition_mask', None),
                                                content_token=content_token,
                                                filter_ratio=filter_ratio,
                                                temperature=temperature,
                                                return_att_weight=return_att_weight)
            out = {'content': self.content_codec.decode(trans_out['content_token'])}
        finally:
            self.train(was_training)

        if return_att_weight:
            cond_att = trans_out['condition_attention']
            content_att = trans_out['content_attention']
            height, width = self.content_codec.token_shape[0], self.content_codec.token_shape[1]
            # last dim L -> H x W, e.g. B x Lt x Lt -> B x Lt x H x W
            content_att = content_att.view(*content_att.shape[:-1], height, width)
            out['condition_attention'] = cond_att
            out['content_attention'] = self._prepend_condition_attention(cond_att, content_att)
        return out

    @torch.no_grad()
    def reconstruct(
        self,
        input
    ):
        """Encode ``input`` with the content codec and decode its ``'token'`` entry back."""
        if torch.is_tensor(input):
            input = input.to(self.device)
        tokens = self._prefix_and_move(self.content_codec.get_tokens(input), 'content_')
        return self.content_codec.decode(tokens['content_token'])

    @torch.no_grad()
    def sample(
        self,
        batch,
        clip=None,
        temperature=1.,
        return_rec=True,
        filter_ratio=(0.5, 0.7, 1),  # the ratios to filter the logits before sampling the logits
        content_ratio=(0, 0.5),  # the ratio to keep the encoded content tokens
        return_att_weight=False,
        return_logits=False,
        **kwargs,
    ):
        """Sample content for every (filter_ratio, content_ratio) combination.

        The model runs in eval mode. Its previous train/eval mode is restored
        afterwards, also when an error is raised. Ratios whose content prefix
        length is negative or not shorter than the full token sequence are
        skipped.

        Args:
            batch: mapping with the raw condition and content entries.
            clip: unused; kept for interface compatibility.
            temperature: forwarded to ``self.transformer.sample``.
            return_rec: also decode the unmodified content tokens.
            filter_ratio: iterable of values forwarded to the transformer.
            content_ratio: iterable of fractions of the content tokens kept as
                a prefix.
            return_att_weight: also return attention maps.
            return_logits: also return ``trans_out['logits']`` under
                ``'logits'`` (only the last sampled setting is kept).
            **kwargs: forwarded to ``self.transformer.sample``.

        Returns:
            dict with ``'condition'``, ``'input_image'``, optionally
            ``'reconstruction_image'``, and one
            ``'cond1_cont{cr}_fr{fr}_image'`` entry (plus attention entries)
            per sampled setting.
        """
        was_training = self.training
        self.eval()
        try:
            condition = self.prepare_condition(batch)
            content = self.prepare_content(batch)
            full_tokens = content['content_token']

            content_samples = {'input_image': batch[self.content_info['key']]}
            if return_rec:
                content_samples['reconstruction_image'] = self.content_codec.decode(full_tokens)

            for fr in filter_ratio:
                for cr in content_ratio:
                    num_content_tokens = int(full_tokens.shape[1] * cr)
                    if num_content_tokens < 0 or num_content_tokens >= full_tokens.shape[-1]:
                        continue
                    trans_out = self.transformer.sample(condition_token=condition['condition_token'],
                                                        condition_mask=condition.get('condition_mask', None),
                                                        content_token=full_tokens[:, :num_content_tokens],
                                                        filter_ratio=fr,
                                                        temperature=temperature,
                                                        return_att_weight=return_att_weight,
                                                        return_logits=return_logits,
                                                        **kwargs)

                    key = 'cond1_cont{}_fr{}_image'.format(cr, fr)
                    content_samples[key] = self.content_codec.decode(trans_out['content_token'])

                    if return_att_weight:
                        content_samples[key + '_condition_attention'] = trans_out['condition_attention']  # B x Lt x Ld
                        content_att = trans_out['content_attention']
                        shape = (*content_att.shape[:-1],
                                 self.content_codec.token_shape[0],
                                 self.content_codec.token_shape[1])
                        content_samples[key + '_content_attention'] = content_att.view(*shape)  # B x Lt x Lt -> B x Lt x H x W
                    if return_logits:
                        content_samples['logits'] = trans_out['logits']
        finally:
            self.train(was_training)

        output = {'condition': batch[self.condition_info['key']]}
        output.update(content_samples)
        return output

    def forward(
        self,
        batch,
        name='none',
        **kwargs
    ):
        """Prepare condition + content tokens from ``batch`` and run the transformer.

        ``name`` is accepted for interface compatibility and is not used here;
        ``kwargs`` are forwarded to ``self.transformer``.
        """
        inputs = self.prepare_input(batch)
        return self.transformer(inputs, **kwargs)
